#!/usr/bin/env python

from PIL import Image

import cv2
import numpy as np
import matplotlib.pylab  as plt
from skimage import io
from skimage import morphology,draw
import time
#from auto.gui import RuKou
try:
    from auto.control import controller
except ImportError:
    from control import controller

try:
    import numpy as np
    
except ImportError:
    raise RuntimeError('cannot import numpy, make sure numpy package is installed')

class autoAgent():
    def __init__(self,processer):
        """Create the agent and start its control processor.

        :param processer: owner object; it is only stored here, and its
            ``exit()`` method is called later from :meth:`exit`.
        """
        print('auto Agent start')

        # Control side: build the controller and start it with speed 10.
        self.controllProcesser = controller()
        self.controllProcesser.setRun(speed=10)

        # Collaborators; controller/world are injected later through
        # getController() and getWorld().
        self.processer = processer
        self.controller = None
        self.world = None

        # Frame bookkeeping used by getPic().
        self.count = 0              # number of frames offered so far
        self.checkFrequency = 20    # only every 20th frame is processed
        self.picProcess = False     # True while a frame is being processed
        self.error = 0

        # Behaviour switches.
        self.ex = False             # when True, exit() returns without exiting
        self.showPic = True         # display intermediate images
        self.savePic = False        # write intermediate images to disk
        self.logging = True         # enable log() output
        self.showTimeBool = False   # enable showTime() output

    def tick(self):
        self.controllProcesser.tick()
    
    def getPic(self,pic,otherSource = False):
        """Offer one frame to the agent and, if it is due, run the pipeline.

        Frames are counted on every call. A frame is processed only when the
        counter is a multiple of ``self.checkFrequency`` or when
        ``otherSource`` is true, and only if no other frame is currently
        being processed.

        :param pic: image array accepted by ``PIL.Image.fromarray``.
        :param otherSource: True for frames that do not come from the live
            source (e.g. test images); such frames bypass the frequency
            check and are not rotated.
        :return: None
        """
        self.count += 1
        if not self.count % self.checkFrequency == 0 and not otherSource:
            return

        if self.picProcess:
            return
        self.picProcess = True
        try:
            self._processFrame(pic,otherSource)
        finally:
            # Always release the guard; otherwise one failing frame would
            # make every later frame return at the check above.
            self.picProcess = False

    def _processFrame(self,pic,otherSource):
        """Run the detection pipeline on one frame (helper of getPic)."""
        print("autpAgent::getPic count:"+str(self.count))

        image = Image.fromarray(pic)

        # No rotation is needed for test images.
        if not otherSource:
            image = image.transpose(Image.ROTATE_270)
        oriImage = image

        # Perspective transform
        self.startTime = time.time()
        self.startTime1 = time.time()
        image = self.toushiTrans(np.asarray(image),show=False,save=False)
        self.showTime('toushi')
        self.save(image,'toushi')

        # Grayscale conversion
        self.time = time.time()
        image = self.color2gray(image,show=False,save=False)
        self.showTime('huidu')
        self.save(image,'huidu')

        # Down-sampling (stands in for the hand-written pooling())
        image = cv2.pyrDown(image)
        self.save(image,'pool')

        # K-means clustering
        k_image = self.Kmeans(image,4)
        self.showTime('Kmeans')
        # NOTE(review): this saves the down-sampled input, not k_image -
        # confirm which one was intended.
        self.save(image,'Kmeans')

        # Skeleton extraction
        s_image = self.skeletonGene(k_image > 0)
        self.showTime('skeletonGene')
        self.show([s_image])
        self.save(s_image,'skeletonGene')
        s_image = np.uint8(s_image)

        # Hough transform: lines is a (2, n) array, row 0 theta, row 1 rho
        lines = self.lines_detector_hough(s_image)
        print('there are %s line detacted'%(lines.shape[1]),lines.shape)

        # Intersections of the detected lines
        [xs,ys] = self.getIntersection(lines)
        self.showTime('houghImg')

        # NOTE: a call to an undefined train_mean_shift(data, 2) used to sit
        # here. Neither name exists in this module and its results were never
        # used, so it only raised NameError; it has been removed.

        self.startTime = self.startTime1
        self.showTime('all')

        print('the postions # = %s'%len(xs))
        for x,y in zip(xs,ys):
            print('the postion is x:[%s] y:[%s]'%(x,y))
            cv2.circle(image, (int(x),int(y)), 8, (255,255,255), 0)

        if self.showPic or len(xs) != 4:
            for theta,rho in lines.T:
                a = np.cos(theta)
                b = np.sin(theta)
                x0 = a*rho
                y0 = b*rho
                # Two far-away points on the line so it spans the image.
                x1 = int(x0 + 10000*(-b))
                y1 = int(y0 + 10000*(a))
                x2 = int(x0 - 10000*(-b))
                y2 = int(y0 - 10000*(a))

                cv2.line(image,(x1,y1),(x2,y2),(0,0,255),2)

            self.show([s_image,image],True)

        # The original condition was `showPic or (len(xs) != 4 and False)`,
        # which reduces to showPic alone.
        if self.showPic:
            print('len(xs) is more than 4,start image display')
            houghImg = self.drawLines(lines,image)
            interSectionImg = self.drawLines(lines,image,showIntersection=True)
            self.show([houghImg,interSectionImg],True)

            self.save(houghImg,'hough')
            self.save(interSectionImg,'intersection')

        if not len(xs) == 4:
            # Keep the evidence on disk, then shut down.
            self.save(oriImage,'ori%s'%time.time(),True)
            self.save(image,'hough%s'%time.time(),True)
            self.exit()

        print("all is ok")

        if self.controller is not None:
            self.controller.throttle = 1.0
        else:
            print('controller is None')

    def save(self,image,name = None,save = False):
        if not self.savePic and not save:
            return

        self.log('save image name:%s  imgType:%s'%(str(name),type(image)))
        if name == None:
            name = time.time()
        if type(image) == Image.Image:
            image = np.uint8(np.asarray(image))

        
        '''
        if type(image) == np.ndarray:
            self.log('the image shape is :%s'%str(image.shape))
            image = Image.fromarray(image)
        
        if image.mode != 'RGB':
            image = image.convert('RGB')
        '''
        #image.save('/out/%s.png'%name)
        print(type(image),type(image[0][0]))
        cv2.imwrite('out/%s.png'%name,image)           


    def log(self,log):
        if self.logging:
            print(log)

    def showTime(self,runFunction = "unKnow"):
        """Print the time elapsed since ``self.startTime`` and restart the timer.

        Does nothing unless ``self.showTimeBool`` is set.

        :param runFunction: label printed in front of the elapsed time.
        """
        if not self.showTimeBool:
            return

        elapsed = time.time() - self.startTime
        print('%s:runtime:%s'%(runFunction,elapsed))
        # Restart the stopwatch for the next stage.
        self.startTime = time.time()


    def lines_detector_hough(self,edge,ThetaDim = None,DistStep = None,threshold = None,halfThetaWindowSize = 2,halfDistWindowSize = None):
        '''
        :param edge: 经过边缘检测得到的二值图
        :param ThetaDim: hough空间中theta轴的刻度数量(将[0,pi)均分为多少份),反应theta轴的粒度,越大粒度越细
        :param DistStep: hough空间中dist轴的划分粒度,即dist轴的最小单位长度
        :param threshold: 投票表决认定存在直线的起始阈值
        :return: 返回检测出的所有直线的参数(theta,dist)
        @author: bilibili-会飞的吴克
        '''
        print('start hough Trans  the shape is %s'%str(edge.shape))
        imgsize = edge.shape
        if ThetaDim == None:
            ThetaDim = 90
        if DistStep == None:
            DistStep = 1
        MaxDist = np.sqrt(imgsize[0]**2 + imgsize[1]**2)
        DistDim = int(np.ceil(MaxDist/DistStep))

        if halfDistWindowSize == None:
            halfDistWindowSize = DistDim/50
        accumulator = np.zeros((ThetaDim,DistDim)) # theta的范围是[0,pi). 在这里将[0,pi)进行了线性映射.类似的,也对Dist轴进行了线性映射

        sinTheta = [np.sin(t*np.pi/ThetaDim) for t in range(ThetaDim)]
        cosTheta = [np.cos(t*np.pi/ThetaDim) for t in range(ThetaDim)]

        for i in range(imgsize[0]):
            for j in range(imgsize[1]):
                if not edge[i,j] == 0:
                    for k in range(ThetaDim):
                        accumulator[k][int(round((i*cosTheta[k]+j*sinTheta[k])*DistDim/MaxDist))] += 1

        M = accumulator.max()

        if threshold == None:
            threshold = int(M*1/10)
        result = np.array(np.where(accumulator > threshold)) # 阈值化
        
        temp = [[],[]]
        noMaxScope = 2
        for i in range(result.shape[1]):
            maxNum1 = int(max(0, result[0,i] - halfThetaWindowSize * noMaxScope + 1))
            maxNum2 = int(max(0, result[1,i] - halfDistWindowSize * noMaxScope + 1))
            minNum1 = int(min(result[0,i] + halfThetaWindowSize * noMaxScope, accumulator.shape[0]))
            minNum2 = int(min(result[1,i] + halfDistWindowSize * noMaxScope, accumulator.shape[1]))

            eight_neiborhood = accumulator[maxNum1:minNum1, maxNum2:minNum2]
            if (accumulator[result[0,i],result[1,i]] >= eight_neiborhood).all():
                temp[0].append(result[0,i])
                temp[1].append(result[1,i])
        
        result = np.array(temp)    # 非极大值抑制
        
        result = result.astype(np.float64)
        result[0] = result[0]*np.pi/ThetaDim
        result[1] = result[1]*MaxDist/DistDim
        #print("hough lines:",result)
        return result

    def noMaxRefrain(self,lines,refrainParam = 1):
        pass

    def getIntersection(self,lines):
        Cos = np.cos(lines[0])
        Sin = np.sin(lines[0])
        r = lines[1]
        #print(Cos,Sin,r)
        xs = []
        ys = []
        self.log('there is %s line to compute'%lines.shape[1])        
        for i in range(lines.shape[1]):
            for j in range(i + 1,lines.shape[1]):
                #print("i:%s j:%s"%(i,j))
                base = Cos[i] * Sin[j] - Cos[j] * Sin[i]
                if base == 0:
                    continue
                xUp = r[i] * Sin[j] - r[j] * Sin[i]
                yUp = r[j] * Cos[i] - r[i] * Cos[j]
                #print('cos[i] : %s  sin[j]:%s cos[j]:%s sin[i]:%s'%(Cos[i],Sin[j],Cos[j],Sin[i]))
                x = xUp / base
                y = yUp / base
                #print('base : %s  xUp:%s  yUp:%s x= %s  y = %s'%(base,xUp,yUp,x,y))
                if x > 1250*2 or y > 500*2 or x < 0 or y < 0:
                    print('point discard : x=%s y=%s %s %s %s %s'%(x,y,x > 1250.0*2 , y > 500.0*2 , x < 0.0 , y < 0.0))
                    continue
                xs.append(round(x))
                ys.append(round(y))
        #print(xs,ys)
        xs = np.array(xs)
        ys = np.array(ys)
        return [xs,ys]

    def drawLines(self,lines,edge,color = (255,255,255),err = 2,showIntersection = False):
        if len(edge.shape) == 2:
            result = np.dstack((edge,edge,edge))
        else:
            result = edge
        Cos = np.cos(lines[0])
        Sin = np.sin(lines[0])
        [xs,ys] = self.getIntersection(lines)

        for i in range(edge.shape[0]):
            for j in range(edge.shape[1]):
                e = np.abs(lines[1] - i*Cos - j*Sin)
                if (e < err).any():
                    result[i,j] = color

                if showIntersection:
                    e = np.square(xs - i) + np.square(ys - j)
                    if (e < err*err*4).any():
                        result[i,j] = (255,0,0)

        '''
        for i in range(len(xs)):
            result[int(xs[i]),int(ys[i])] = (255,0,0)
        '''
        return result



    def getController(self,controller):
        print("autoAgent::getController")
        self.controller = controller
        self.controllProcesser.getVechicalController(controller)

    def getWorld(self,world):
        print('autoAgent::getWorld')
        self.world=world
        self.controllProcesser.getWorld(world)

    def exit(self):
        print('autoAgent::exit')
        if self.ex:
            return

        if self.world != None:
            self.processer.exit()
        exit(0)

    def color2gray(self, img, show=False, save=False):
        """Convert a colour image to grayscale via ``cv2.COLOR_BGR2GRAY``.

        :param img: colour image array.
            NOTE(review): the conversion assumes BGR channel order - confirm
            that callers do not pass RGB data.
        :param show: plot input and output side by side.
        :param save: write the result to ``gray_img.png``.
        :return: the grayscale image.
        """
        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        if show:
            print('gray_img::show')
            panels = ((121, img, 'Input'), (122, gray_img, 'Output'))
            for position, picture, caption in panels:
                plt.subplot(position)
                plt.imshow(picture)
                plt.title(caption)
            plt.show()

        if save:
            print('gray_img::save')
            io.imsave('gray_img.png', gray_img)

        return gray_img
        

    def toushiTrans(self, img, show=False, save=False):
        """Apply a fixed perspective warp to ``img``.

        The warp maps four hard-coded source points onto four hard-coded
        target points and renders the result with ``dsize=(2000, 5000)``.

        :param img: input image array.
        :param show: plot input and output side by side.
        :param save: write the result to ``out.png``.
        :return: the warped image.
        """
        print(('toushiTrans show:%s save:%s')%(show,save))

        srcCorners = np.float32([[304, 368], [735, 370], [137, 456], [783, 456]])
        dstCorners = np.float32([[100, 100], [800, 100], [100, 450], [800, 450]])
        # Shift the target points in place so the array stays float32.
        dstCorners += 200
        dstCorners += [170, 4000]

        warpMatrix = cv2.getPerspectiveTransform(srcCorners, dstCorners)
        dst = cv2.warpPerspective(img, warpMatrix, (2000, 5000))

        if show:
            print('toushiTrans::show')
            panels = ((121, img, 'Input'), (122, dst, 'Output'))
            for position, picture, caption in panels:
                plt.subplot(position)
                plt.imshow(picture)
                plt.title(caption)
            plt.show()

        if save:
            print('toushiTrans::save')
            io.imsave('out.png', dst)

        return dst

    def Kmeans(self,img,k):
        print('start Kmeans k = ',k)
        #centers = np.random.rand(1,k) * 255
        centers = (np.arange(0,k,1) * 255 / k).reshape(1,k)
        totalNum = img.shape[0]*img.shape[1]
        vector = img.reshape([totalNum,1]).copy()
        master = np.zeros([totalNum,1])
        iters = 0
        loss = 1
        cost = 99999999999999999
        while iters < 100 and not loss == 0:
            iters += 1
            dist = np.abs(centers - vector)
            master = np.argmin(dist,axis=1)
            lastCost = cost
            cost = np.sum(dist)
            loss = lastCost - cost


            for i in range(k):
                indexs = np.array(list(np.where(master == i)))
                indexs = indexs.reshape(indexs.shape[1])
                belongPoints = vector[indexs]
                centers[0,i] = np.mean(belongPoints)

            #print('iter:%s cost:%s loss:%s'%(iters,cost,loss))
        
        needIndex = np.argmax(centers,axis=1)[0]
        #print(needIndex.shape,'  needIndex')
        imgs = []
        for i in range(4):
            indexs = np.array(list(np.where(master == i)))
            vector = np.zeros([totalNum,1])
            vector[indexs] = 255
            imgBuffer = vector.reshape(img.shape)
            imgs.append(imgBuffer)

        img0 = imgs[0]
        imgs[0] = imgs[needIndex]
        imgs[needIndex] = img0

        self.show(imgs)




        
        return imgs[0]

    def show(self,images,show = False):
        if show or self.showPic:       
            print('show picture num = %s'%len(images))
            if len(images) == 2:
                plt.subplot(121), plt.imshow(images[0]), plt.title('Input')
                plt.subplot(122), plt.imshow(images[1]), plt.title('Output')
            elif len(images) == 4:
                plt.subplot(221), plt.imshow(images[0]), plt.title('0')
                plt.subplot(222), plt.imshow(images[1]), plt.title('1')            
                plt.subplot(223), plt.imshow(images[2]), plt.title('2')
                plt.subplot(224), plt.imshow(images[3]), plt.title('3')
            elif len(images) == 1:
                plt.imshow(images[0])
            else:
                print('can`t show  number is wrong')
                return           
            plt.show()

    def pooling(self,inputMap,poolSize=3,poolStride=2,mode='max'):
        """INPUTS:
                inputMap - input array of the pooling layer
                poolSize - X-size(equivalent to Y-size) of receptive field
                poolStride - the stride size between successive pooling squares
        
        OUTPUTS:
                outputMap - output array of the pooling layer
                
        Padding mode - 'edge'
        """
        # inputMap sizes
        in_row,in_col = np.shape(inputMap)
        
        # outputMap sizes
        out_row,out_col = int(np.floor(in_row/poolStride)),int(np.floor(in_col/poolStride))
        row_remainder,col_remainder = np.mod(in_row,poolStride),np.mod(in_col,poolStride)
        if row_remainder != 0:
            out_row +=1
        if col_remainder != 0:
            out_col +=1
        outputMap = np.zeros((out_row,out_col))
        
        # padding
        temp_map = np.lib.pad(inputMap, ((0,poolSize-row_remainder),(0,poolSize-col_remainder)), 'edge')
        
        # max pooling
        for r_idx in range(0,out_row):
            for c_idx in range(0,out_col):
                startX = c_idx * poolStride
                startY = r_idx * poolStride
                poolField = temp_map[startY:startY + poolSize, startX:startX + poolSize]
                poolOut = np.max(poolField)
                outputMap[r_idx,c_idx] = poolOut
        
        #self.show([inputMap,outputMap])

        # retrun outputMap
        return  outputMap

    def skeletonGene(self,img):
        """Skeletonize ``img`` with ``morphology.skeletonize`` and show input and result.

        :param img: image passed unchanged to ``morphology.skeletonize``.
        :return: the skeleton image.
        """
        thinned = morphology.skeletonize(img)
        self.show([img,thinned])
        return thinned

    def point2rsita(self,lines):
        outlines = []
        for line in lines:
            r = 0
            sita = 0
            if line[0] == line[2]:
                sita = 0
                r = line[0]
            elif line[1] == line[3]:
                sita = np.pi / 2
                r = line[1]
            else:
                sita = np.arctan((line[2] - line[0])/(line[3] - line[1]))
                h1 = line[0] / np.tan(sita)
                h = h1 + line[1]
                r = h * np.sin(sita)
            outlines.append([r,sita])
        return np.array(outlines)





        


    